{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference an audio label by using the model trained on kaggle dataset\n",
    "## Test audios from both kaggle dataset, as well as my own.\n",
    "## The result is super good !"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "ROOT = \"../\"\n",
    "sys.path.append(ROOT)\n",
    "\n",
    "import numpy as np \n",
    "import torch \n",
    "\n",
    "if 1: # my lib\n",
    "    import utils.lib_io as lib_io\n",
    "    import utils.lib_commons as lib_commons\n",
    "    import utils.lib_datasets as lib_datasets\n",
    "    import utils.lib_augment as lib_augment\n",
    "    import utils.lib_ml as lib_ml\n",
    "    import utils.lib_rnn as lib_rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load weights from  ../checkpoints/kaggle_accu_914/model_025.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Create model\n",
    "\n",
    "args = lib_rnn.set_default_args()\n",
    "args.classes_txt = \"../config/classes_kaggle.names\" \n",
    "args.load_weights_from=\"../checkpoints/kaggle_accu_914/model_025.ckpt\"\n",
    "\n",
    "model = lib_rnn.create_RNN_model(args, args.load_weights_from)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['bird', 'happy', 'stop', 'zero', 'six', 'learn', 'off', 'no', 'sheila', 'forward', 'eight', 'one', 'seven', 'backward', 'tree', 'bed', 'wow', 'right', 'dog', 'follow', 'up', 'down', 'nine', 'two', 'cat', 'yes', 'on', 'visual', 'left', 'five', 'four', 'marvin', 'go', 'house', 'three']\n"
     ]
    }
   ],
   "source": [
    "# Load labels\n",
    "\n",
    "classes = lib_io.read_list(args.classes_txt)\n",
    "print(classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function for randomly reading an audio \n",
    "def get_a_random_wav(audio_folder):\n",
    "    names = lib_commons.get_filenames(audio_folder, \"*/*.wav\")\n",
    "    idx = np.random.randint(len(names))\n",
    "    name = names[idx]\n",
    "    audio = lib_datasets.AudioClass(filename=name)\n",
    "    return audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict audio label\n",
    "def predict(x):\n",
    "    idx = model.predict(x)\n",
    "    return classes[idx]\n",
    "\n",
    "def predict_audio(audio):\n",
    "    if audio.mfcc is None:\n",
    "        audio.compute_mfcc()\n",
    "    x = audio.mfcc.T\n",
    "    predicted_label = predict(x)\n",
    "    true_label = audio.filename.split('/')[-2]\n",
    "    print(f\"Label: true = {true_label}, predict = {predicted_label}\")\n",
    "    return predicted_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label: true = no, predict = no\n"
     ]
    }
   ],
   "source": [
    "# Test on kaggle dataset\n",
    "\n",
    "audio_folder = \"../data/kaggle/\"\n",
    "audio = get_a_random_wav(audio_folder)\n",
    "predicted_label = predict_audio(audio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label: true = right, predict = right\n"
     ]
    }
   ],
   "source": [
    "# Test on my own data\n",
    "\n",
    "audio_folder = \"../data/data_train/\"\n",
    "audio = get_a_random_wav(audio_folder)\n",
    "predicted_label = predict_audio(audio)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
